week 7: multilevel models

multilevel adventures

divergent transitions

From McElreath:

Recall that HMC simulates the frictionless flow of a particle on a surface. In any given transition, which is just a single flick of the particle, the total energy at the start should be equal to the total energy at the end. That’s how energy in a closed system works. And in a purely mathematical system, the energy is always conserved correctly. It’s just a fact about the physics.

But in a numerical system, it might not be. Sometimes the total energy is not the same at the end as it was at the start. In these cases, the energy is divergent. How can this happen? It tends to happen when the posterior distribution is very steep in some region of parameter space. Steep changes in probability are hard for a discrete physics simulation to follow. When that happens, the algorithm notices by comparing the energy at the start to the energy at the end. When they don’t match, it indicates numerical problems exploring that part of the posterior distribution.
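To make that concrete, here is a toy sketch of my own (not McElreath's code): a leapfrog integrator for a standard-normal target. With a small step size the total energy at the end of the trajectory is nearly the same as at the start; with a step size that is too large, it blows up, which is exactly the kind of numerical divergence the sampler detects.

```r
# Toy leapfrog integrator for a standard-normal target, U(q) = q^2/2.
# Total energy H = U(q) + K(p) should be conserved along the trajectory.
leapfrog <- function(q, p, eps, steps) {
  grad_U <- function(q) q              # gradient of U(q) = q^2/2
  for (s in seq_len(steps)) {
    p <- p - eps / 2 * grad_U(q)       # half-step for momentum
    q <- q + eps * p                   # full step for position
    p <- p - eps / 2 * grad_U(q)       # half-step for momentum
  }
  c(q = q, p = p)
}
H <- function(q, p) q^2 / 2 + p^2 / 2  # potential + kinetic energy

H(0, 1)                                                          # 0.5 at the start
with(as.list(leapfrog(0, 1, eps = 0.1, steps = 100)), H(q, p))   # still close to 0.5
with(as.list(leapfrog(0, 1, eps = 2.2, steps = 100)), H(q, p))   # astronomically large
```

The leapfrog scheme is stable for this target only when the step size is small enough; past that threshold the trajectory diverges exponentially, and Stan flags the transition.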

centered parameterization

In his lecture, McElreath uses CENTERED PARAMETERIZATION to demonstrate divergent transitions. A very simple example:

\[\begin{align*} y_i &\sim \text{Normal}(\mu, \sigma) \\ \mu &\sim \text{Normal}(0,1) \\ \sigma &\sim \text{Exponential}(1) \\ \end{align*}\]

This expression is centered because the distribution of one set of parameters (the \(y_i\)) is centered on another parameter with its own prior (\(\mu\)). It’s intuitive, but it can cause a lot of problems for Stan, which is probably why McElreath used it for his example. In short, when there is limited data within our groups or the population variance is small, the parameters \(y_i\) become highly correlated with the hyperparameters \(\mu\) and \(\sigma\). This geometry is challenging for MCMC to sample. (Think of a long and narrow groove, not a bowl, for your Hamiltonian skateboard.)
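To see that geometry, here is a quick sketch of my own using the priors above: draw \(\sigma\), then draw \(y\) given \(\sigma\), and plot \(y\) against \(\log(\sigma)\). Small values of \(\sigma\) pinch the draws into the narrow neck of the funnel.

```r
# Sketch of the funnel geometry implied by a hierarchical prior:
# sigma ~ Exponential(1), y | sigma ~ Normal(0, sigma).
set.seed(7)
sigma <- rexp(1e4, rate = 1)
y <- rnorm(1e4, mean = 0, sd = sigma)
plot(y, log(sigma), pch = ".",
     xlab = "y", ylab = "log(sigma)",
     main = "Long narrow groove, not a bowl")
```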

The way to fix this is by using an uncentered parameterization:

\[\begin{align*} y_i &= \mu + (\sigma \times z_i) \\ z_i &\sim \text{Normal}(0, 1) \\ \mu &\sim \text{Normal}(0,1) \\ \sigma &\sim \text{Exponential}(1) \\ \end{align*}\]
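The two parameterizations describe the same distribution; as a quick sanity check (mine, not from the lecture):

```r
# Non-centered draws mu + sigma * z, with z ~ Normal(0, 1), have the same
# distribution as centered draws from Normal(mu, sigma).
set.seed(1)
mu <- 2
sigma <- 0.5
z <- rnorm(1e5)
y_noncentered <- mu + sigma * z
c(mean = mean(y_noncentered), sd = sd(y_noncentered))  # approximately 2 and 0.5
```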

It’s an important point, but the issues with centered parameterization are so prevalent that brms generally doesn’t allow it (with some exceptions). So we can’t recreate this divergent transition.

McElreath describes the problem of fertility in Bangladesh as follows:

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha_{D_{[i]}} \\ \alpha_j &\sim \text{Normal}(\bar{\alpha}, \sigma) \\ \bar{\alpha} &\sim \text{Normal}(0, 1) \\ \sigma &\sim \text{Exponential}(1) \\ \end{align*}\]

But to fit this using brms, we’ll rewrite as:

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha + \alpha_{D[i]} \\ \alpha &\sim \text{Normal}(0, 1) \\ \alpha_{D[j]} &\sim \text{Normal}(0, \sigma_{D}) \\ \sigma_{D} &\sim \text{Exponential}(1) \end{align*}\]

data(bangladesh, package="rethinking")
d <- bangladesh

m1 <- brm(
  data=d,
  family=bernoulli,
  use.contraception ~ 1 + (1 | district),
  prior = c( prior(normal(0, 1), class = Intercept), # alpha bar
             prior(exponential(1), class = sd)),       # sigma

  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/data/generated_data/m71.1"))
m1
 Family: bernoulli 
  Links: mu = logit 
Formula: use.contraception ~ 1 + (1 | district) 
   Data: d (Number of observations: 1934) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~district (Number of levels: 60) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.52      0.09     0.37     0.70 1.00     1374     1915

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.54      0.09    -0.72    -0.37 1.00     1998     2342

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
gather_draws(m1, b_Intercept, r_district[district, ]) %>% 
  with_groups(c(.variable, district), median_qi, .value)
# A tibble: 61 × 8
# Groups:   .variable, district [61]
   .variable   district  .value .lower  .upper .width .point .interval
   <chr>          <int>   <dbl>  <dbl>   <dbl>  <dbl> <chr>  <chr>    
 1 b_Intercept       NA -0.536  -0.715 -0.369    0.95 median qi       
 2 r_district         1 -0.454  -0.864 -0.0464   0.95 median qi       
 3 r_district         2 -0.0482 -0.757  0.610    0.95 median qi       
 4 r_district         3  0.301  -0.702  1.35     0.95 median qi       
 5 r_district         4  0.343  -0.239  0.964    0.95 median qi       
 6 r_district         5 -0.0297 -0.592  0.510    0.95 median qi       
 7 r_district         6 -0.275  -0.773  0.197    0.95 median qi       
 8 r_district         7 -0.216  -0.945  0.478    0.95 median qi       
 9 r_district         8  0.0236 -0.567  0.603    0.95 median qi       
10 r_district         9 -0.162  -0.866  0.453    0.95 median qi       
# ℹ 51 more rows
gather_draws(m1, b_Intercept, r_district[district, ]) %>% 
  with_groups(c(.variable, district), median_qi, .value) %>% 
  ggplot(aes( x=district, y=.value)) +
  geom_pointinterval( aes(ymin = .lower, ymax = .upper), 
                      alpha=.5) +
  labs(y="District distance from mean") +
  coord_flip()

\[\begin{align*} C &\sim \text{Bernoulli}(p_i) \\ \text{logit}(p_i) &= \alpha + \alpha_{D[i]} + \beta U_i + \beta_{D[i]}U_i \\ \alpha, \beta &\sim \text{Normal}(0, 1) \\ \alpha_{D[j]} &\sim \text{Normal}(0, \sigma_{D}) \\ \beta_{D[j]} &\sim \text{Normal}(0, \tau_{D}) \\ \sigma_{D}, \tau_{D} &\sim \text{Exponential}(1) \\ \end{align*}\]

m2 <- brm(
  data=d,
  family=bernoulli,
  use.contraception ~ 1 + urban + (1 + urban || district),
  prior = c( prior(normal(0, 1), class = Intercept), 
             prior(normal(0, 1), class = b),
             prior(exponential(1), class = sd)),     

  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/data/generated_data/m71.2"))

The `||` in the formula gives each district its own intercept and urban slope while assuming no correlation between them, matching the independent priors \(\sigma_{D}\) and \(\tau_{D}\) above. Oops, no divergent transitions.

m2
 Family: bernoulli 
  Links: mu = logit 
Formula: use.contraception ~ 1 + urban + (1 + urban || district) 
   Data: d (Number of observations: 1934) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~district (Number of levels: 60) 
              Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(Intercept)     0.48      0.09     0.32     0.67 1.01     1290     2067
sd(urban)         0.55      0.21     0.11     0.96 1.00      860      912

Regression Coefficients:
          Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
Intercept    -0.70      0.09    -0.88    -0.53 1.00     2275     2893
urban         0.63      0.15     0.33     0.92 1.00     2391     2077

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).

more about divergent transitions

From Gelman et al (2020)

more than one type of cluster

McElreath doesn’t cover this in his video lecture, but this is from the textbook and worth discussing.

data(chimpanzees, package="rethinking")
d <- chimpanzees
str(d)
'data.frame':   504 obs. of  8 variables:
 $ actor       : int  1 1 1 1 1 1 1 1 1 1 ...
 $ recipient   : int  NA NA NA NA NA NA NA NA NA NA ...
 $ condition   : int  0 0 0 0 0 0 0 0 0 0 ...
 $ block       : int  1 1 1 1 1 1 2 2 2 2 ...
 $ trial       : int  2 4 6 8 10 12 14 16 18 20 ...
 $ prosoc_left : int  0 0 1 0 1 1 1 1 0 0 ...
 $ chose_prosoc: int  1 0 0 1 1 1 0 0 1 1 ...
 $ pulled_left : int  0 1 0 0 1 1 0 0 0 0 ...
unique(d$actor)
[1] 1 2 3 4 5 6 7
unique(d$block)
[1] 1 2 3 4 5 6
unique(d$prosoc_left)
[1] 0 1
unique(d$condition)
[1] 0 1

We could model the interaction between condition (presence/absence of another animal) and option (which side is prosocial), but it is more difficult to assign sensible priors to interaction effects. Another option, because we’re working with categorical variables, is to turn our 2x2 into one variable with 4 levels.

d$treatment <- 1 + d$prosoc_left + 2*d$condition
d %>% count(treatment, prosoc_left, condition)
  treatment prosoc_left condition   n
1         1           0         0 126
2         2           1         0 126
3         3           0         1 126
4         4           1         1 126

In this experiment, each pull is within a cluster of pulls belonging to an individual chimpanzee. But each pull is also within an experimental block, which represents a collection of observations that happened on the same day. So each observed pull belongs to both an actor (1 to 7) and a block (1 to 6). There may be unique intercepts for each actor as well as for each block.
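For comparison, the same crossed structure (varying intercepts for both actor and block) can also be written in ordinary brms syntax. This is a sketch of my own, not from the book, and it gives less fine-grained control over priors than the non-linear formulation below:

```r
# Hypothetical alternative: crossed varying intercepts in plain brms syntax.
# Sketch only, not run here; assumes d and brms are already loaded.
m3_alt <- brm(
  data = d,
  family = bernoulli,
  pulled_left ~ 0 + factor(treatment) + (1 | actor) + (1 | block),
  prior = c(prior(normal(0, 0.5), class = b),
            prior(exponential(1), class = sd)),
  chains = 4, cores = 4, iter = 2000, warmup = 1000,
  seed = 1)
```

Note that dropping the intercept puts Normal(0, 0.5) directly on the four treatment means, which is tighter than \(\bar{\alpha} \sim \text{Normal}(0, 1.5)\) plus an offset; the non-linear syntax below exists precisely to keep those priors separate.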

Mathematical model:

\[\begin{align*} L_i &\sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) &= \bar{\alpha} + \alpha_{\text{ACTOR[i]}} + \gamma_{\text{BLOCK[i]}} + \beta_{\text{TREATMENT[i]}} \\ \beta_j &\sim \text{Normal}(0, 0.5) \text{ , for }j=1..4\\ \alpha_j &\sim \text{Normal}(0, \sigma_{\alpha}) \text{ , for }j=1..7\\ \gamma_j &\sim \text{Normal}(0, \sigma_{\gamma}) \text{ , for }j=1..6\\ \bar{\alpha} &\sim \text{Normal}(0, 1.5) \\ \sigma_{\alpha} &\sim \text{Exponential}(1) \\ \sigma_{\gamma} &\sim \text{Exponential}(1) \\ \end{align*}\]

m3 <- 
  brm(
    family = bernoulli,
    data = d, 
    bf(
      pulled_left ~ a + b, 
      a ~ 1 + (1 | actor) + (1 | block), 
      b ~ 0 + treatment, 
      nl = TRUE),
    prior = c(prior(normal(0, 0.5), nlpar = b),
              prior(normal(0, 1.5), class = b, coef = Intercept, nlpar = a),
              prior(exponential(1), class = sd, group = actor, nlpar = a),
              prior(exponential(1), class = sd, group = block, nlpar = a)),
  chains=4, cores=4, iter=2000, warmup=1000,
  seed = 1,
  file = here("files/data/generated_data/m71.3")
  )
m3
 Family: bernoulli 
  Links: mu = logit 
Formula: pulled_left ~ a + b 
         a ~ 1 + (1 | actor) + (1 | block)
         b ~ 0 + treatment
   Data: d (Number of observations: 504) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~actor (Number of levels: 7) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     2.00      0.65     1.06     3.54 1.00     1198     1608

~block (Number of levels: 6) 
                Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
sd(a_Intercept)     0.22      0.21     0.01     0.65 1.00     1131     1244

Regression Coefficients:
            Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
a_Intercept     0.53      0.73    -0.96     1.98 1.01      698     1116
b_treatment     0.05      0.09    -0.13     0.23 1.00     4076     2604

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
posterior_summary(m3)
                             Estimate  Est.Error          Q2.5        Q97.5
b_a_Intercept              0.52503260 0.73111087 -9.581334e-01    1.9815720
b_b_treatment              0.04559243 0.09200118 -1.349792e-01    0.2267296
sd_actor__a_Intercept      1.99783341 0.65020002  1.061929e+00    3.5406747
sd_block__a_Intercept      0.21551277 0.20954345  6.135531e-03    0.6501890
r_actor__a[1,Intercept]   -0.96793617 0.73703910 -2.447011e+00    0.4459541
r_actor__a[2,Intercept]    3.94453862 1.31977776  1.935123e+00    6.9596606
r_actor__a[3,Intercept]   -1.26903212 0.73800364 -2.769589e+00    0.1311174
r_actor__a[4,Intercept]   -1.26255367 0.74027384 -2.773092e+00    0.2293022
r_actor__a[5,Intercept]   -0.96852948 0.73790662 -2.442937e+00    0.4910080
r_actor__a[6,Intercept]   -0.05696122 0.74248157 -1.513062e+00    1.4361641
r_actor__a[7,Intercept]    1.43231847 0.78057577 -1.308899e-01    3.0229796
r_block__a[1,Intercept]   -0.17039823 0.23004831 -7.519251e-01    0.1154111
r_block__a[2,Intercept]    0.05593816 0.19731669 -2.881960e-01    0.4848865
r_block__a[3,Intercept]    0.05547394 0.19462158 -2.930123e-01    0.5015258
r_block__a[4,Intercept]   -0.01175494 0.19160806 -4.162736e-01    0.3641352
r_block__a[5,Intercept]   -0.01234951 0.18989539 -4.148848e-01    0.3585132
r_block__a[6,Intercept]    0.10139751 0.20294348 -2.045939e-01    0.5720760
lprior                    -3.96463354 0.75902901 -5.782336e+00   -2.8652951
lp__                    -290.74334856 3.55784209 -2.984567e+02 -284.7255442
m3 %>% 
  mcmc_plot(variable = c("^r_", "^b_", "^sd_"), regex = T) +
  theme(axis.text.y = element_text(hjust = 0))
as_draws_df(m3) %>% 
  select(starts_with("sd")) %>% 
  pivot_longer(everything()) %>% 
  ggplot(aes(x = value, fill = name)) +
  geom_density(linewidth = 0, alpha = 3/4, adjust = 2/3, show.legend = F) +
  annotate(geom = "text", x = 0.67, y = 2, label = "block", color = "#5e8485") +
  annotate(geom = "text", x = 2.725, y = 0.5, label = "actor", color = "#0f393a") +
  scale_fill_manual(values = c("#0f393a", "#5e8485")) +
  scale_y_continuous(NULL, breaks = NULL) +
  ggtitle(expression(sigma["group"])) +
  coord_cartesian(xlim = c(0, 4))